MoEGEMM as an extension of GroupGEMM #520
base: main
Conversation
@sanchitintel how is the progress of the PR?
Force-pushed from ef117e2 to 1545982
Hi @airMeng, this PR doesn't have the updated code with performance optimizations, which I have locally. It does have the updated API interface in the example, though. However, if @Antonyvance & the cutlass team wouldn't want MoE GEMM as a separate kernel, then it'd be better to port it to the existing GroupGEMM kernel instead. Can you please explain why you asked? Thanks!
Background

When multiple GEMMs are to be computed, each with its own canonical `A`, `B`, `C`, `D` matrices, GroupGEMM is useful for ensuring high GPU utilization & avoiding the launch overhead that would otherwise be incurred by multiple GEMM kernel launches. In cutlass, the vanilla GroupGEMM uses a persistent-kernel approach: the number of workgroups launched is equal to the number of Xe cores, and each workgroup loops, picking up work until none remains (here, a unit of work is the mainloop that computes one output tile of one of the GEMMs submitted through the GroupGEMM API).

For Mixture of Experts (MoE) as used in deep-learning models such as LLMs, the MoE GEMM use-case looks like this: each `expert` (corresponding to a `group`) has an associated `weight` of size `N * K`, which is essentially a column-major `B` matrix. All the `B` matrices are contiguous w.r.t. each other, i.e. their total size is `num_groups * N * K`. `N` and `K` are compile-time constants, while `M` is variable per group. All `A` matrices are also contiguous w.r.t. each other; the set of tokens routed to an expert makes up the `A` matrix for that group.
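For concreteness, here is a minimal sketch of that layout; the struct and member names (and the element type) are hypothetical, not identifiers from this PR:

```cpp
// Illustrative sketch of the MoE operand layout described above; all names
// are hypothetical, not identifiers from this PR.
struct MoeOperands {
  // All B (weight) matrices are contiguous: group g's column-major N x K
  // weight starts at element offset g * N * K.
  const float* base_B;      // num_groups * N * K elements total

  // All A matrices are contiguous: group g's A holds the M_g tokens routed
  // to expert g, so it starts at row offset M_0 + ... + M_{g-1}.
  const float* base_A;      // (M_0 + M_1 + ... + M_{num_groups-1}) x K

  float*       base_D;      // outputs, laid out like A but with N columns

  int N, K;                 // compile-time constants in the PR
  const int*   M_per_group; // variable per group, known only after routing
};
```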
`MoEGEMM` thus seems to be a natural candidate for leveraging GroupGEMM.

The problem
The cutlass GroupGEMM API is generic in that it requires pointers to the `A`, `B`, `C`, `D` tensors of each group. To launch the kernel, the CPU needs to provide an array of these GPU pointers (the array itself also lives on the GPU). However, for practical use-cases such as Mixture of Experts (each GroupGEMM `group` corresponds to one MoE `expert`), such lists can't be conveniently pre-computed in advance. (It is indeed possible to create them at the beginning of the kernel and then synchronize across all workgroups, but that code can't be part of a generic GroupGEMM.)
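To make the inconvenience concrete, here is a rough sketch of the kind of per-group pointer lists the generic API consumes; the helper and its names are hypothetical, not the actual cutlass argument structs:

```cpp
// Sketch of the per-group pointer lists a generic GroupGEMM expects.
// Hypothetical helper and names; the actual cutlass argument structs differ.
#include <cstdint>
#include <vector>

struct GroupPtrLists {
  std::vector<const float*> ptr_A, ptr_B;  // one device pointer per group
  std::vector<float*>       ptr_D;
};

// Host-side assembly: needs every group's M before launch. For MoE, the
// per-group M's are only known after token routing (often on the device),
// which is what makes precomputing these lists inconvenient.
GroupPtrLists make_group_ptrs(const float* base_A, const float* base_B,
                              float* base_D,
                              const std::vector<int>& M_per_group,
                              int N, int K) {
  GroupPtrLists p;
  std::int64_t row_offset = 0;
  for (std::size_t g = 0; g < M_per_group.size(); ++g) {
    p.ptr_A.push_back(base_A + row_offset * K);
    p.ptr_B.push_back(base_B + std::int64_t(g) * N * K);
    p.ptr_D.push_back(base_D + row_offset * N);
    row_offset += M_per_group[g];
  }
  return p;  // these lists would then be copied into a device-side array
}
```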
Solution proposed in this PR

Provide only the base `A`, `B`, `C`, `D` pointers, and also pass `N` and `K`, so that the canonical `A`, `B`, `C`, `D` pointers for each group can be computed on the fly. (A prefix-sum algorithm to compute a cumulative sum of `M` might help, but based on our experimentation it doesn't seem to make much difference, as the small-`M` case is memory-bound anyway.) To keep the changes to the existing code minimal, lists of size one are passed instead of lists sized equal to the number of groups, as would otherwise happen in the default case.
The PR adds a new kernel & a tile scheduler for MoEGEMM, while reusing the existing MMA & epilogue collectives (with modified code for the `A`, `B`, `C`, `D` pointer computation). We could instead add a template parameter to the existing kernels and use `if constexpr` to separate the MoE path from the default GroupGEMM path. While the current implementation in this PR introduces some duplication, the alternative would make the code messier.
Performance

With a small `M` dimension for each GEMM problem, performance is worse than with a large `M` dimension due to the lower arithmetic intensity of the former case, but it is still better than launching a separate kernel for each GEMM problem.
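As a rough back-of-the-envelope illustration (generic numbers, not measurements from this PR): one group performs about `2 * M * N * K` FLOPs while reading at least the `N * K * sizeof(element)` bytes of its `B` matrix, so for small `M` the arithmetic intensity is on the order of `2 * M / sizeof(element)` FLOPs per byte, i.e. roughly `M` FLOPs/byte for half-precision elements. For single-digit `M` this sits far below typical GPU compute-to-bandwidth ratios, which is why the small-`M` case is memory-bound; batching all the groups into one persistent-kernel launch at least avoids paying per-expert launch overhead.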
Caveat

The example portrays just one way to use the API. Also, it has mostly been copy-pasted from an existing example, so it can be revised further.